from enum import StrEnum
from typing import IO, Any
from urllib.parse import quote, urlencode, urlparse

import aioboto3
import structlog
from types_boto3_batch.client import BatchClient
from types_boto3_ec2.client import EC2Client
from types_boto3_ecs.client import ECSClient
from types_boto3_s3.client import S3Client
from types_boto3_secretsmanager.client import SecretsManagerClient

from skyvern.config import settings

# Module-level structlog logger; the AWS helpers in this module use it to report failures.
LOG = structlog.get_logger()


# We only include the storage classes that we want to use in our application.
class S3StorageClass(StrEnum):
    STANDARD = "STANDARD"
    # REDUCED_REDUNDANCY = "REDUCED_REDUNDANCY"
    # INTELLIGENT_TIERING = "INTELLIGENT_TIERING"
    ONEZONE_IA = "ONEZONE_IA"
    GLACIER = "GLACIER"
    # DEEP_ARCHIVE = "DEEP_ARCHIVE"
    # OUTPOSTS = "OUTPOSTS"
    # STANDARD_IA = "STANDARD_IA"


class AWSClientType(StrEnum):
    S3 = "s3"
    SECRETS_MANAGER = "secretsmanager"
    ECS = "ecs"
    EC2 = "ec2"
    BATCH = "batch"


class AsyncAWSClient:
    """Async wrapper around aioboto3 clients for S3, Secrets Manager, ECS, EC2 and Batch.

    Every public method opens a fresh service client from ``self.session`` with
    ``async with`` and closes it before returning.
    """

    # S3 DeleteObjects accepts at most 1000 keys per request.
    _S3_DELETE_OBJECTS_MAX_KEYS = 1000

    def __init__(
        self,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        region_name: str | None = None,
        endpoint_url: str | None = None,
        profile_name: str | None = None,
    ) -> None:
        """Create the underlying aioboto3 session.

        Args:
            aws_access_key_id: Access key passed straight to ``aioboto3.Session``.
            aws_secret_access_key: Secret key passed straight to ``aioboto3.Session``.
            region_name: Region used for every client; defaults to ``settings.AWS_REGION``.
            endpoint_url: Optional endpoint override passed to every client.
            profile_name: Optional named profile passed to ``aioboto3.Session``.
        """
        self.region_name = region_name or settings.AWS_REGION
        self._endpoint_url = endpoint_url
        self.session = aioboto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            profile_name=profile_name,
        )

    # Client factories. Each returns a client that callers in this class enter with
    # ``async with`` so it is closed after use.
    def _ecs_client(self) -> ECSClient:
        return self.session.client(AWSClientType.ECS, region_name=self.region_name, endpoint_url=self._endpoint_url)

    def _secrets_manager_client(self) -> SecretsManagerClient:
        return self.session.client(
            AWSClientType.SECRETS_MANAGER, region_name=self.region_name, endpoint_url=self._endpoint_url
        )

    def _s3_client(self) -> S3Client:
        return self.session.client(AWSClientType.S3, region_name=self.region_name, endpoint_url=self._endpoint_url)

    def _ec2_client(self) -> EC2Client:
        return self.session.client(AWSClientType.EC2, region_name=self.region_name, endpoint_url=self._endpoint_url)

    def _batch_client(self) -> BatchClient:
        return self.session.client(AWSClientType.BATCH, region_name=self.region_name, endpoint_url=self._endpoint_url)

    def _create_tag_string(self, tags: dict[str, str]) -> str:
        """Encode ``tags`` as the URL-query string S3 expects for its ``Tagging`` parameter.

        Keys and values are percent-encoded, so characters such as ``&``, ``=``, ``+``,
        ``%`` or spaces cannot corrupt the tag set. Plain alphanumeric tags still encode
        to exactly ``k1=v1&k2=v2``.
        """
        return urlencode(tags, quote_via=quote)

    async def get_secret(self, secret_name: str) -> str | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager/client/get_secret_value.html
        """Return the ``SecretString`` of ``secret_name``, or ``None`` if the lookup fails.

        Any failure is logged together with the AWS error code when one is available.
        """
        try:
            async with self._secrets_manager_client() as client:
                response = await client.get_secret_value(SecretId=secret_name)
                return response["SecretString"]
        except Exception as e:
            # Only botocore ClientErrors carry a ``response`` dict with an error code.
            try:
                error_code = e.response["Error"]["Code"]  # type: ignore
            except Exception:
                error_code = "failed-to-get-error-code"
            LOG.exception("Failed to get secret.", secret_name=secret_name, error_code=error_code)
            return None

    async def create_secret(self, secret_name: str, secret_value: str) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager/client/create_secret.html
        """Create a new secret holding ``secret_value``. Errors are logged and re-raised."""
        try:
            async with self._secrets_manager_client() as client:
                await client.create_secret(Name=secret_name, SecretString=secret_value)
        except Exception:
            LOG.exception("Failed to create secret.", secret_name=secret_name)
            raise

    async def set_secret(self, secret_name: str, secret_value: str) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager/client/put_secret_value.html
        """Store ``secret_value`` as a new value of an existing secret. Errors are logged and re-raised."""
        try:
            async with self._secrets_manager_client() as client:
                await client.put_secret_value(SecretId=secret_name, SecretString=secret_value)
        except Exception:
            LOG.exception("Failed to set secret.", secret_name=secret_name)
            raise

    async def delete_secret(self, secret_name: str) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager/client/delete_secret.html
        """Delete ``secret_name``. Errors are logged and re-raised."""
        try:
            async with self._secrets_manager_client() as client:
                await client.delete_secret(SecretId=secret_name)
        except Exception:
            LOG.exception("Failed to delete secret.", secret_name=secret_name)
            raise

    async def upload_file(
        self,
        uri: str,
        data: bytes,
        storage_class: S3StorageClass = S3StorageClass.STANDARD,
        tags: dict[str, str] | None = None,
    ) -> str | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/put_object.html
        """Upload ``data`` to the S3 object named by ``uri``.

        Returns:
            ``uri`` on success, ``None`` if the upload failed (the failure is logged).

        Raises:
            ValueError: If ``storage_class`` is not an ``S3StorageClass`` member.
        """
        if storage_class not in S3StorageClass:
            raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
        try:
            async with self._s3_client() as client:
                parsed_uri = S3Uri(uri)
                extra_args = {"Tagging": self._create_tag_string(tags)} if tags else {}
                await client.put_object(
                    Body=data,
                    Bucket=parsed_uri.bucket,
                    Key=parsed_uri.key,
                    StorageClass=str(storage_class),
                    **extra_args,
                )
                return uri
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)
            return None

    async def upload_file_stream(
        self,
        uri: str,
        file_obj: IO[bytes],
        storage_class: S3StorageClass = S3StorageClass.STANDARD,
        tags: dict[str, str] | None = None,
    ) -> str | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/upload_fileobj.html#upload-fileobj
        """Upload the contents of ``file_obj`` to the S3 object named by ``uri``.

        Returns:
            ``uri`` on success, ``None`` if the upload failed (the failure is logged).

        Raises:
            ValueError: If ``storage_class`` is not an ``S3StorageClass`` member.
        """
        if storage_class not in S3StorageClass:
            raise ValueError(f"Invalid storage class: {storage_class}. Must be one of {list(S3StorageClass)}")
        try:
            async with self._s3_client() as client:
                parsed_uri = S3Uri(uri)
                extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
                if tags:
                    extra_args["Tagging"] = self._create_tag_string(tags)
                await client.upload_fileobj(
                    file_obj,
                    parsed_uri.bucket,
                    parsed_uri.key,
                    ExtraArgs=extra_args,
                )
                LOG.debug("Upload file stream success", uri=uri)
                return uri
        except Exception:
            LOG.exception("S3 upload stream failed.", uri=uri)
            return None

    async def upload_file_from_path(
        self,
        uri: str,
        file_path: str,
        storage_class: S3StorageClass = S3StorageClass.STANDARD,
        metadata: dict | None = None,
        raise_exception: bool = False,
        tags: dict[str, str] | None = None,
    ) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/upload_file.html
        """Upload the local file at ``file_path`` to the S3 object named by ``uri``.

        Failures are always logged. They are re-raised only when ``raise_exception`` is True.
        """
        try:
            async with self._s3_client() as client:
                parsed_uri = S3Uri(uri)
                extra_args: dict[str, Any] = {"StorageClass": str(storage_class)}
                if metadata:
                    extra_args["Metadata"] = metadata
                if tags:
                    extra_args["Tagging"] = self._create_tag_string(tags)
                await client.upload_file(
                    Filename=file_path,
                    Bucket=parsed_uri.bucket,
                    Key=parsed_uri.key,
                    ExtraArgs=extra_args,
                )
        except Exception:
            LOG.exception("S3 upload failed.", uri=uri)
            if raise_exception:
                raise

    async def download_file(self, uri: str, log_exception: bool = True) -> bytes | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_object.html
        """Return the full body of the S3 object at ``uri``, or ``None`` on failure."""
        try:
            async with self._s3_client() as client:
                parsed_uri = S3Uri(uri)
                response = await client.get_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
                # The body stream must be read while the client is still open.
                return await response["Body"].read()
        except Exception:
            if log_exception:
                LOG.exception("S3 download failed", uri=uri)
            return None

    async def delete_file(self, uri: str, log_exception: bool = True) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/delete_object.html
        """Delete the S3 object at ``uri``. Failures are swallowed (optionally logged)."""
        try:
            async with self._s3_client() as client:
                parsed_uri = S3Uri(uri)
                await client.delete_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)
        except Exception:
            if log_exception:
                LOG.exception("S3 delete failed", uri=uri)

    async def get_object_info(self, uri: str) -> dict:
        """Return the raw ``head_object`` response (metadata only, no body). Errors propagate."""
        async with self._s3_client() as client:
            parsed_uri = S3Uri(uri)
            return await client.head_object(Bucket=parsed_uri.bucket, Key=parsed_uri.key)

    async def get_file_metadata(
        self,
        uri: str,
        log_exception: bool = True,
    ) -> dict | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/head_object.html
        """
        Retrieves only the user metadata of a file without downloading its content.

        Args:
            uri: The S3 URI of the file
            log_exception: Whether to log exceptions

        Returns:
            The ``Metadata`` dictionary (empty if the object has none), or None if the request fails
        """
        try:
            response = await self.get_object_info(uri)
            return response.get("Metadata", {})
        except Exception:
            if log_exception:
                LOG.exception("S3 metadata retrieval failed", uri=uri)
            return None

    async def create_presigned_urls(self, uris: list[str]) -> list[str] | None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/generate_presigned_url.html
        """Return a presigned GET URL per entry of ``uris`` (same order), or ``None`` if any fails.

        URLs expire after ``settings.PRESIGNED_URL_EXPIRATION`` seconds.
        """
        presigned_urls = []
        try:
            async with self._s3_client() as client:
                for uri in uris:
                    parsed_uri = S3Uri(uri)
                    url = await client.generate_presigned_url(
                        "get_object",
                        Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
                        ExpiresIn=settings.PRESIGNED_URL_EXPIRATION,
                    )
                    presigned_urls.append(url)

                return presigned_urls
        except Exception:
            LOG.exception("Failed to create presigned url for S3 objects.", uris=uris)
            return None

    async def list_files(self, uri: str) -> list[str]:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/paginator/ListObjectsV2.html
        """Return every object key whose name starts with the key part of ``uri``.

        All result pages are followed. Errors from S3 propagate to the caller.
        """
        object_keys: list[str] = []
        parsed_uri = S3Uri(uri)
        async with self._s3_client() as client:
            paginator = client.get_paginator("list_objects_v2")
            async for page in paginator.paginate(Bucket=parsed_uri.bucket, Prefix=parsed_uri.key):
                # A page may omit "Contents" entirely when it has no matching objects.
                object_keys.extend(obj["Key"] for obj in page.get("Contents", []))
        return object_keys

    async def delete_files(self, bucket: str, keys: list[str]) -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/delete_objects.html
        """
        Delete multiple objects from an S3 bucket.

        DeleteObjects accepts at most 1000 keys per request, so ``keys`` is sent in batches
        of that size. Per-key failures reported by S3 are logged and do not raise.

        Args:
            bucket: The S3 bucket name
            keys: List of object keys to delete; an empty list is a no-op

        Raises:
            Exception: Any error raised by the S3 client is logged and re-raised. Batches sent
                before the failing one have already been applied.
        """
        if not keys:
            return

        batch_size = self._S3_DELETE_OBJECTS_MAX_KEYS
        try:
            async with self._s3_client() as client:
                for start in range(0, len(keys), batch_size):
                    batch = keys[start : start + batch_size]
                    response = await client.delete_objects(
                        Bucket=bucket,
                        Delete={
                            "Objects": [{"Key": key} for key in batch],
                            "Quiet": False,  # Set to True to suppress per-key success details
                        },
                    )

                    # Log any errors that occurred during deletion of this batch
                    for error in response.get("Errors", []):
                        LOG.error(
                            "Failed to delete object from S3",
                            bucket=bucket,
                            key=error.get("Key"),
                            code=error.get("Code"),
                            message=error.get("Message"),
                        )
        except Exception:
            LOG.exception("Failed to delete files from S3", bucket=bucket, keys_count=len(keys))
            raise

    async def restore_object(self, bucket: str, key: str, days: int = 1, tier: str = "Standard") -> None:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/restore_object.html
        """
        Restore an archived S3 object from GLACIER storage class.

        Args:
            bucket: The S3 bucket name
            key: The S3 object key
            days: Number of days to keep the restored object available (default: 1)
            tier: Restoration tier, e.g. "Standard" (3-5 hours) or "Expedited" (1-5 minutes)

        Raises:
            Exception: Any error from S3 is logged and re-raised.
        """
        try:
            async with self._s3_client() as client:
                await client.restore_object(
                    Bucket=bucket, Key=key, RestoreRequest={"Days": days, "GlacierJobParameters": {"Tier": tier}}
                )
        except Exception:
            LOG.exception("Failed to restore S3 object", bucket=bucket, key=key, tier=tier)
            raise

    async def run_task(
        self,
        cluster: str,
        launch_type: str,
        task_definition: str,
        subnets: list[str],
        security_groups: list[str],
        assign_public_ip: str = "DISABLED",
        enable_execute_command: bool = False,
    ) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/run_task.html
        """Start an ECS task with awsvpc networking and return the raw ``run_task`` response."""
        async with self._ecs_client() as client:
            return await client.run_task(
                cluster=cluster,
                launchType=launch_type,
                taskDefinition=task_definition,
                networkConfiguration={
                    "awsvpcConfiguration": {
                        "subnets": subnets,
                        "securityGroups": security_groups,
                        "assignPublicIp": assign_public_ip,
                    }
                },
                enableExecuteCommand=enable_execute_command,
            )

    async def stop_task(self, cluster: str, task: str, reason: str | None = None) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/stop_task.html
        """Stop an ECS task and return the raw ``stop_task`` response.

        ``reason`` is only sent when provided: botocore's parameter validation rejects an
        explicit ``None`` for this string field.
        """
        request: dict[str, Any] = {"cluster": cluster, "task": task}
        if reason is not None:
            request["reason"] = reason
        async with self._ecs_client() as client:
            return await client.stop_task(**request)

    async def describe_tasks(self, cluster: str, tasks: list[str]) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/describe_tasks.html
        """Return the raw ``describe_tasks`` response for ``tasks`` in ``cluster``."""
        async with self._ecs_client() as client:
            return await client.describe_tasks(cluster=cluster, tasks=tasks)

    async def list_tasks(self, cluster: str) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/list_tasks.html
        """Return the raw (single-page) ``list_tasks`` response for ``cluster``."""
        async with self._ecs_client() as client:
            return await client.list_tasks(cluster=cluster)

    async def describe_task_definition(self, task_definition: str) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/describe_task_definition.html
        """Return the raw ``describe_task_definition`` response."""
        async with self._ecs_client() as client:
            return await client.describe_task_definition(taskDefinition=task_definition)

    async def deregister_task_definition(self, task_definition: str) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs/client/deregister_task_definition.html
        """Deregister ``task_definition`` and return the raw response."""
        async with self._ecs_client() as client:
            return await client.deregister_task_definition(taskDefinition=task_definition)

    ###### EC2 ######
    async def describe_network_interfaces(self, network_interface_ids: list[str]) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2/client/describe_network_interfaces.html
        """Return the raw ``describe_network_interfaces`` response for the given ENI ids."""
        async with self._ec2_client() as client:
            return await client.describe_network_interfaces(NetworkInterfaceIds=network_interface_ids)

    ###### Batch ######
    async def describe_job(self, job_id: str) -> dict:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/describe_jobs.html
        """Return the description of ``job_id``, or ``{}`` if AWS Batch returns no job for it."""
        async with self._batch_client() as client:
            response = await client.describe_jobs(jobs=[job_id])
            return response["jobs"][0] if response["jobs"] else {}

    async def list_jobs(self, job_queue: str, job_status: str) -> list[dict]:
        # NOTE: AWS batch only records the latest 7 days jobs by default
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/list_jobs.html
        """Return all job summaries in ``job_queue`` with ``job_status``, across every page."""
        async with self._batch_client() as client:
            total_jobs: list[dict] = []
            async for page in client.get_paginator("list_jobs").paginate(jobQueue=job_queue, jobStatus=job_status):
                total_jobs.extend(page["jobSummaryList"])

            return total_jobs

    async def submit_job(
        self,
        job_name: str,
        job_queue: str,
        job_definition: str,
        params: dict,
        job_priority: int | None = None,
        share_identifier: str | None = None,
        container_overrides: dict | None = None,
        depends_on_ids: list[str] | None = None,
    ) -> str | None:
        """Submit an AWS Batch job and return its ``jobId`` (``None`` if absent from the response).

        ``job_priority`` and ``share_identifier`` are only sent as a pair. If either one is
        ``None``, both are omitted from the request.
        """
        request: dict[str, Any] = {
            "jobName": job_name,
            "jobQueue": job_queue,
            "jobDefinition": job_definition,
            "parameters": params,
            "containerOverrides": container_overrides or {},
            "dependsOn": [{"jobId": job_id} for job_id in depends_on_ids or []],
        }
        # NOTE(review): a share_identifier passed without a priority is silently dropped;
        # confirm this is intended for fair-share queues.
        if job_priority is not None and share_identifier is not None:
            request["schedulingPriorityOverride"] = job_priority
            request["shareIdentifier"] = share_identifier
        async with self._batch_client() as client:
            response = await client.submit_job(**request)
            return response.get("jobId")


class S3Uri:
    # Adapted from: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
    """
    Split an ``s3://bucket/key`` URI into its bucket and key.

    Fragment parsing is disabled, so ``#`` stays in the key verbatim. Anything after ``?``
    is also treated as part of the key rather than as a query string.

    >>> s = S3Uri("s3://bucket/hello/world")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world'
    >>> s.uri
    's3://bucket/hello/world'

    >>> s = S3Uri("s3://bucket/hello/world?qwe1=3#ddd")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world?qwe1=3#ddd'
    >>> s.uri
    's3://bucket/hello/world?qwe1=3#ddd'

    >>> s = S3Uri("s3://bucket/hello/world#foo?bar=2")
    >>> s.key
    'hello/world#foo?bar=2'
    >>> s.uri
    's3://bucket/hello/world#foo?bar=2'
    """

    def __init__(self, uri: str) -> None:
        self._parsed = urlparse(uri, allow_fragments=False)

    @property
    def bucket(self) -> str:
        """The bucket name (the URI's network-location part)."""
        return self._parsed.netloc

    @property
    def key(self) -> str:
        """The object key: the path without leading slashes, plus ``?query`` if one was present."""
        path_part = self._parsed.path.lstrip("/")
        query_part = self._parsed.query
        return f"{path_part}?{query_part}" if query_part else path_part

    @property
    def uri(self) -> str:
        """The URI reassembled from its parsed components."""
        return self._parsed.geturl()

    def __str__(self) -> str:
        return self.uri


def tag_set_to_dict(tag_set: list[dict[str, str]]) -> dict[str, str]:
    """Convert a list of ``{"Key": ..., "Value": ...}`` tag entries into a plain dict.

    If a key appears more than once, the last entry wins.
    """
    result: dict[str, str] = {}
    for entry in tag_set:
        result[entry["Key"]] = entry["Value"]
    return result


# Shared module-level client, created at import time. No explicit credentials, endpoint or
# profile are passed, so the region comes from settings.AWS_REGION and the remaining
# aioboto3.Session arguments are None.
aws_client = AsyncAWSClient()
